package com.angel.mongodb.core;

import cn.hutool.core.util.ReflectUtil;
import cn.hutool.core.util.StrUtil;
import com.angel.mongodb.annotation.InitValue;
import com.angel.mongodb.annotation.MongoOperateLog;
import com.angel.mongodb.annotation.OperateType;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.data.domain.Sort;
import org.springframework.data.mongodb.core.MongoTemplate;
import org.springframework.data.mongodb.core.query.Query;
import org.springframework.data.mongodb.core.query.Update;

import java.lang.reflect.Field;
import java.lang.reflect.Modifier;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
 * Generic base implementation of the MongoDB DAO operations declared by {@link BaseMongoDAO}.
 *
 * <p>Entities are handled reflectively. They are expected to expose a {@code String} property
 * named {@code id}. Optional {@code createTime}/{@code updateTime} properties are stamped with
 * {@link System#currentTimeMillis()} by {@link #saveOrUpdate}, {@link #insertAll} and
 * {@link #updateAllColumnById} (query-based updates do not touch them). {@code null} properties
 * annotated with {@link InitValue} are filled with their declared default before an insert.
 *
 * @description: Generic MongoDB DAO base implementation
 * @author: gankench@gmail.com
 * @time: 4/9/21 6:57 PM
 */
public class MongoDaoSupport implements BaseMongoDAO {

    /** Name of the identifier property every entity is expected to declare. */
    private static final String ID = "id";

    /** Optional entity property stamped with the insertion time. */
    private static final String CREATE_TIME = "createTime";

    /** Optional entity property stamped with the last save time. */
    private static final String UPDATE_TIME = "updateTime";

    @Autowired
    @Qualifier("mongoTemplate")
    protected MongoTemplate mongoTemplate;

    /**
     * Inserts the object, or merges it into the stored document when one with the same id exists.
     *
     * <p>Insert path (empty id, or id not found): {@code createTime}/{@code updateTime} are
     * stamped, {@link InitValue} defaults are applied, and any caller-supplied id is cleared so a
     * new id is generated on save.
     *
     * <p>Update path: every non-null instance property of {@code object} except the id is copied
     * onto the stored document, {@code updateTime} is stamped, and the merged document is saved.
     * {@code null} properties never overwrite stored values.
     *
     * @param object entity to insert or merge
     * @return the id of the inserted or updated document
     */
    @Override
    @MongoOperateLog(operateType = OperateType.SAVE)
    public String saveOrUpdate(Object object) {
        long time = System.currentTimeMillis();
        String id = (String) ReflectUtil.getFieldValue(object, ID);
        Object objectOrg = StrUtil.isNotEmpty(id) ? getMongoTemplate().findById(id, object.getClass()) : null;

        if (objectOrg == null) {
            stampInsertTimes(object, time);
            setDefaultValue(object);
            // The supplied id (if any) does not exist yet: drop it so a fresh id is generated.
            ReflectUtil.setFieldValue(object, ID, null);
            mongoTemplate.save(object);
            // save() populates the generated id back onto the object.
            return (String) ReflectUtil.getFieldValue(object, ID);
        }

        copyNonNullProperties(object, objectOrg);
        setIfPresent(objectOrg, UPDATE_TIME, time);
        mongoTemplate.save(objectOrg);
        return id;
    }

    /**
     * Inserts all objects in a single batch.
     *
     * <p>Objects whose id already exists in the database have their id cleared so they are
     * inserted as new documents rather than colliding. Every object gets
     * {@code createTime}/{@code updateTime} stamped and its {@link InitValue} defaults applied.
     *
     * @param <T>  entity type
     * @param list objects to insert (modified in place)
     */
    @Override
    @MongoOperateLog(operateType = OperateType.SAVE)
    public <T> void insertAll(List<T> list) {
        long time = System.currentTimeMillis();

        // Collect the caller-supplied ids that are already stored; each distinct id is looked up once.
        Set<String> existingIds = new HashSet<>();
        for (Object object : list) {
            String id = (String) ReflectUtil.getFieldValue(object, ID);
            if (StrUtil.isNotEmpty(id) && !existingIds.contains(id)
                    && getMongoTemplate().findById(id, object.getClass()) != null) {
                existingIds.add(id);
            }
        }

        for (Object object : list) {
            String id = (String) ReflectUtil.getFieldValue(object, ID);
            if (id != null && existingIds.contains(id)) {
                // A document with this id already exists: clear the id so a new one is inserted.
                ReflectUtil.setFieldValue(object, ID, null);
            }
            stampInsertTimes(object, time);
            setDefaultValue(object);
        }
        mongoTemplate.insertAll(list);
    }

    /**
     * Saves the object as-is over the document with the same id (no merge with stored values),
     * stamping {@code updateTime} when present. Does nothing when the id is null or empty.
     *
     * @param object entity to save
     */
    @Override
    @MongoOperateLog(operateType = OperateType.SAVE)
    public void updateAllColumnById(Object object) {
        if (StrUtil.isEmpty((String) ReflectUtil.getFieldValue(object, ID))) {
            return;
        }
        setIfPresent(object, UPDATE_TIME, System.currentTimeMillis());
        mongoTemplate.save(object);
    }

    /**
     * Deletes every document of {@code clazz} matching the query.
     *
     * @param query filter
     * @param clazz entity class (selects the collection)
     * @return number of deleted documents
     */
    @Override
    @MongoOperateLog(operateType = OperateType.DELETE)
    public Long deleteByQuery(Query query, Class<?> clazz) {
        return mongoTemplate.remove(query, clazz).getDeletedCount();
    }

    /**
     * Applies the update to the first document matching the query.
     *
     * @param query  filter
     * @param update update to apply
     * @param clazz  entity class (selects the collection)
     * @return number of modified documents
     */
    @Override
    @MongoOperateLog(operateType = OperateType.UPDATE)
    public Long updateFirst(Query query, Update update, Class<?> clazz) {
        return mongoTemplate.updateFirst(query, update, clazz).getModifiedCount();
    }

    /**
     * Applies the update to every document matching the query.
     *
     * @param query  filter
     * @param update update to apply
     * @param clazz  entity class (selects the collection)
     * @return number of modified documents
     */
    @Override
    @MongoOperateLog(operateType = OperateType.UPDATE)
    public Long updateMulti(Query query, Update update, Class<?> clazz) {
        return mongoTemplate.updateMulti(query, update, clazz).getModifiedCount();
    }

    /**
     * Finds a document by id.
     *
     * @param id    document id
     * @param clazz entity class
     * @param <T>   entity type
     * @return the entity, or {@code null} when {@code id} is null/empty or not found
     */
    @Override
    @MongoOperateLog(operateType = OperateType.QUERY)
    public <T> T findById(String id, Class<T> clazz) {
        if (StrUtil.isEmpty(id)) {
            return null;
        }
        return getMongoTemplate().findById(id, clazz);
    }

    /**
     * Finds a single document matching the query.
     *
     * <p>Note: sets {@code limit(1)} on the supplied query object.
     *
     * @param query filter
     * @param clazz entity class
     * @param <T>   entity type
     * @return the first matching entity, or {@code null} when none matches
     */
    @Override
    @MongoOperateLog(operateType = OperateType.QUERY)
    public <T> T findOneByQuery(Query query, Class<T> clazz) {
        query.limit(1);
        return getMongoTemplate().findOne(query, clazz);
    }

    /**
     * Finds every document matching the query.
     *
     * <p>When the query has no sort, a descending sort on {@code id} is added to the supplied
     * query object.
     *
     * @param <T>   entity type
     * @param query filter
     * @param clazz entity class
     * @return matching entities
     */
    @Override
    @MongoOperateLog(operateType = OperateType.QUERY)
    public <T> List<T> findListByQuery(Query query, Class<T> clazz) {
        if (!query.isSorted()) {
            query.with(Sort.by(Sort.Direction.DESC, ID));
        }
        return getMongoTemplate().find(query, clazz);
    }

    /**
     * Finds the ids of every document matching the query.
     *
     * <p>Note: adds an {@code include("id")} projection to the supplied query object, so reusing
     * that query afterwards will also return only ids.
     *
     * @param query filter
     * @param clazz entity class
     * @return ids of matching documents
     */
    @Override
    @MongoOperateLog(operateType = OperateType.QUERY)
    public List<String> findIdsByQuery(Query query, Class<?> clazz) {
        query.fields().include(ID);
        List<?> list = getMongoTemplate().find(query, clazz);
        List<String> ids = new ArrayList<>(list.size());
        for (Object object : list) {
            ids.add((String) ReflectUtil.getFieldValue(object, ID));
        }
        return ids;
    }

    /**
     * Counts the documents matching the query.
     *
     * @param query filter
     * @param clazz entity class
     * @return number of matching documents
     */
    @Override
    @MongoOperateLog(operateType = OperateType.COUNT)
    public Long findCountByQuery(Query query, Class<?> clazz) {
        return getMongoTemplate().count(query, clazz);
    }

    /**
     * Returns the underlying MongoDB template.
     *
     * @return the injected {@link MongoTemplate}
     */
    @Override
    public MongoTemplate getMongoTemplate() {
        return mongoTemplate;
    }

    /** Sets {@code fieldName} on {@code target} only if {@code ReflectUtil.getField} finds that field. */
    private static void setIfPresent(Object target, String fieldName, Object value) {
        if (ReflectUtil.getField(target.getClass(), fieldName) != null) {
            ReflectUtil.setFieldValue(target, fieldName, value);
        }
    }

    /** Stamps {@code createTime} and {@code updateTime} (each only if present) for a new document. */
    private static void stampInsertTimes(Object target, long time) {
        setIfPresent(target, CREATE_TIME, time);
        setIfPresent(target, UPDATE_TIME, time);
    }

    /**
     * Copies every non-null instance property of {@code source}, except the id, onto {@code target}.
     *
     * <p>Static fields (e.g. {@code serialVersionUID}, loggers, constants) are skipped. They are not
     * document state, and writing a {@code static final} field reflectively throws.
     */
    private static void copyNonNullProperties(Object source, Object target) {
        for (Field field : ReflectUtil.getFields(source.getClass())) {
            String name = field.getName();
            if (Modifier.isStatic(field.getModifiers())
                    || "serialVersionUID".equalsIgnoreCase(name)
                    || ID.equals(name)) {
                continue;
            }
            Object value = ReflectUtil.getFieldValue(source, field);
            if (value != null) {
                ReflectUtil.setFieldValue(target, field, value);
            }
        }
    }

    /**
     * Fills every {@code null} field annotated with {@link InitValue} with the annotation's value,
     * converted to the field's type.
     *
     * @param object object whose defaults are applied (modified in place)
     * @throws NumberFormatException if a numeric field's {@code InitValue} cannot be parsed
     */
    private static void setDefaultValue(Object object) {
        for (Field field : ReflectUtil.getFields(object.getClass())) {
            if (!field.isAnnotationPresent(InitValue.class)
                    || ReflectUtil.getFieldValue(object, field) != null) {
                continue;
            }
            Object converted = parseInitValue(field.getType(), field.getAnnotation(InitValue.class).value());
            if (converted != null) {
                ReflectUtil.setFieldValue(object, field, converted);
            }
        }
    }

    /**
     * Converts an {@link InitValue} literal to {@code type}.
     *
     * <p>Supports String and the boxed Short, Integer, Long, Float, Double and Boolean types.
     * Primitive fields are never {@code null}, so they never reach this method.
     *
     * @return the converted value, or {@code null} when {@code type} is not supported
     */
    private static Object parseInitValue(Class<?> type, String value) {
        if (type == String.class) {
            return value;
        }
        if (type == Short.class) {
            return Short.valueOf(value);
        }
        if (type == Integer.class) {
            return Integer.valueOf(value);
        }
        if (type == Long.class) {
            return Long.valueOf(value);
        }
        if (type == Float.class) {
            return Float.valueOf(value);
        }
        if (type == Double.class) {
            return Double.valueOf(value);
        }
        if (type == Boolean.class) {
            return Boolean.valueOf(value);
        }
        return null;
    }
}
